import logging
import math
from itertools import product, islice, cycle
from contextlib import contextmanager
from collections import deque
from typing import Callable
from dataclasses import dataclass, field

import numpy as np
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau
from scipy.optimize import minimize

from ..utils import SingularGramError, DebugSet
from .gp import MeasureGoal, MeasureMeta


def logspace(a, b, size, base):
    """Return ``size`` points going from ``a`` (first) to ``b`` (last) with logarithmic spacing.

    The points are ``f(t) * (a - b) + b`` for ``t = linspace(1, 0, size)``, where
    ``f(t) = (base ** t - 1) / (base - 1)``. For ``base > 1`` the points cluster near ``b``.
    ``base == 1`` is the linear limit ``f(t) = t`` (the formula itself would divide by zero).

    Args:
        a: value of the first point.
        b: value of the last point.
        size: number of points.
        base: logarithm base controlling how strongly points cluster near ``b``.
    """
    t = torch.linspace(1, 0, size)
    if base == 1:
        weights = t
    else:
        weights = (base ** t - 1) / (base - 1)
    return weights * (a - b) + b


def shape_step(index, shape, step=1):
    """Advance a multi-dimensional ``index`` by ``step`` positions in row-major (C) order.

    The flat position wraps around modulo the total number of elements, so stepping past the last
    element returns to the first (and a negative step wraps the other way). Used to compute the index
    of the next direction in sequential optimization.
    """
    total = np.prod(shape)
    position = np.ravel_multi_index(index, shape)
    wrapped = (position + step) % total
    return np.unravel_index(wrapped, shape)


class BayesianOptimization:
    """Drive an optimizer by measuring the points it proposes and feeding the results back.

    Each :meth:`step` asks ``optimizer`` for the points to measure. The optimizer stores the readout
    counts in ``optim_state['n_readout']``. The points are measured with ``sampler``, the evaluation
    counters are updated, and the results go to ``optimizer.update``.
    """

    def __init__(self, model, optimizer, sampler, error_fn=None, cheat=False, var=None):
        """
        Args:
            model: surrogate model passed to the optimizer and to ``error_fn``.
            optimizer: object providing ``init_state(model)``, ``__call__(model, state)`` and
                ``update(model, state, x, y, y_var)``.
            sampler: measurement backend providing ``true_energy`` (and ``true_grad`` when the
                optimizer state has ``'grad'`` set).
            error_fn: optional ``error_fn(model, state)``; its result is returned by :meth:`step`.
            cheat: if True, expose ``sampler.true_energy`` to the optimizer as ``state['true_fn']``.
            var: forwarded to the sampler as ``var=``. With ``'none'`` the sampler results are
                treated as values only; otherwise as ``(value, variance)`` pairs.
        """
        self.model = model
        self.optimizer = optimizer
        self.sampler = sampler
        self.error_fn = error_fn
        self.optim_state = self.optimizer.init_state(self.model)
        self.var = var
        # cheating!
        if cheat:
            logging.warning('Cheating is allowed, uh-oh!')
            self.optim_state['true_fn'] = sampler.true_energy

    def step(self, step):
        """Perform one optimization step.

        Returns:
            ``error_fn(model, state)`` if an ``error_fn`` was given, otherwise None.

        Raises:
            RuntimeError: if the number of proposed points differs from the number of non-zero readouts.
        """
        state = self.optim_state
        state['step'] = step
        x_best = self.optimizer(self.model, state)
        # as_tensor: optimizers may already store a tensor here (avoids torch.tensor's copy warning)
        n_readout = torch.as_tensor(state['n_readout'])

        if n_readout.dim() != 0 and not bool(n_readout.any()):
            # every readout is zero: there is nothing to measure or add to the model this step
            return None if self.error_fn is None else self.error_fn(self.model, state)

        if n_readout.dim() == 0:
            # a single readout count applies to every proposed point
            n_readout = n_readout.expand(len(x_best))

        measured_readouts = torch.masked_select(n_readout, n_readout != 0)
        if len(x_best) != len(measured_readouts):
            raise RuntimeError('Length of numbers to measure does not match number of readouts!')

        use_grad = bool(state.get('grad'))
        true_fn = self.sampler.true_grad if use_grad else self.sampler.true_energy

        stats = [
            true_fn(x_b[None], n_r, var=self.var)
            for x_b, n_r in zip(x_best, measured_readouts)
        ]

        if self.var == 'none' or use_grad:
            y_best = torch.cat(stats)
            y_var = None
        else:
            y_best, y_var = [torch.cat(elem) for elem in zip(*stats)]

        # a gradient measurement is accounted as ``2 * n_angles`` circuit evaluations
        dim_scale = self.sampler.n_angles * 2 if use_grad else 1

        state['n_qc_readout'] += int(n_readout.sum()) * dim_scale
        state['n_qc_eval'] += len(y_best) * dim_scale

        try:
            self.optimizer.update(self.model, state, x_best, y_best, y_var)
        except SingularGramError:
            logging.info('Updated Gram matrix is non-pd! No points added!')

        if self.error_fn is not None:
            return self.error_fn(self.model, state)
        return None


class AcquisitionFunction:
    """Base interface for acquisition functions.

    Subclasses are expected to override ``__call__``; this base version does nothing and returns None.
    """

    def __call__(self, model, x_cand, step):
        return None


@dataclass
class Optimizer:
    '''Base optimizer: holds the measurement configuration and creates/updates the state dict.

    Subclasses implement ``__call__(model, state)`` to propose the points to measure (setting
    ``state['n_readout']``), and may override :meth:`update` to consume the measurements.
    '''
    _measure_goals = (MeasureGoal.PLAIN,)

    acquisition_fn: AcquisitionFunction = None
    sampler: Callable = None
    n_readout: int = None
    debug: DebugSet = field(default_factory=DebugSet)

    def init_state(self, model):
        '''Create the initial state, starting from the training point with the lowest ``y_train``.'''
        min_energy_ind = model.y_train.argmin()
        return {
            'step': 0,
            'n_qc_eval': len(model),
            'n_qc_readout': sum(elem.readout for elem in model.meta if elem.readout is not None),
            'n_readout': self.n_readout,
            'x_start': model.x_train[min_energy_ind],
            'x_best': model.x_train[min_energy_ind],
            'y_start': model.y_train[min_energy_ind],
            'y_start_std': model.y_var[min_energy_ind],
            'y_best': model.y_train[min_energy_ind],
            'goal': self._measure_goals,
            'grad': False,
            'log': {},
        }

    def update(self, model, state, x_measured, y_measured, y_var=None):
        '''Add the measurements to ``model`` and move the start point to the measured point.

        ``state['n_readout']`` may be a sequence with one count per goal, or a single count (int or
        None). A single count is applied to every goal. ``x_start``/``y_start`` (and
        ``x_best``/``y_best`` on improvement) are updated even if ``model.update`` raises.
        '''
        try:
            readouts = state['n_readout']
            is_sequence = isinstance(readouts, (tuple, list)) or (
                isinstance(readouts, torch.Tensor) and readouts.dim() > 0
            )
            if not is_sequence:
                # a scalar count cannot be zipped with the goals; repeat it for each goal instead
                readouts = (readouts,) * len(state['goal'])
            model.update(
                x_measured,
                y_measured,
                y_var,
                meta=[
                    MeasureMeta(step=state['step'], goal=goal, readout=readout)
                    for goal, readout in zip(state['goal'], readouts)
                ]
            )
        finally:
            state['x_start'] = x_measured.squeeze(0)
            state['y_start'] = y_measured
            if state['y_start'].item() < state['y_best'].item():
                state['y_best'] = state['y_start']
                state['x_best'] = state['x_start']

    def __call__(self, model, state):
        pass


@dataclass
class SGDOptimizer(Optimizer):
    '''First-order optimizer stepping along gradients returned by the sampler.

    The pivot is held in ``state['x_param']`` (a ``torch.nn.Parameter``) and stepped by a torch
    optimizer selected through ``gdoptim`` (``'sgd'`` or ``'adam'``).
    '''
    lr: float = 1e-2
    momentum: float = 0.9
    gdoptim: str = 'sgd'

    def _optimizer(self, params):
        '''Build the torch optimizer named by ``gdoptim`` over ``params``.'''
        factories = {
            'sgd': lambda: torch.optim.SGD(params, lr=self.lr, momentum=self.momentum),
            'adam': lambda: torch.optim.Adam(params, lr=self.lr),
        }
        factory = factories.get(self.gdoptim)
        if factory is None:
            raise RuntimeError(f'No such SGD optimizer: {self.gdoptim}')
        return factory()

    def init_state(self, model):
        '''Extend the base state with the trainable pivot, its optimizer and ``grad=True``.'''
        state = super().init_state(model)
        pivot = torch.nn.Parameter(state['x_start'].clone())
        state.update(x_param=pivot, optim=self._optimizer([pivot]), grad=True)
        return state

    def update(self, model, state, x_measured, y_measured, y_var=None):
        '''Record the gradient measured at the pivot; it is applied on the next call.

        Raises:
            RuntimeError: if not exactly one point was measured, if it differs from the pivot, or if
                ``y_measured`` does not have the shape of ``x_measured``.
        '''
        if len(x_measured) != 1:
            raise RuntimeError('SGD Optimizer expects exactly one measured point!')
        if not torch.allclose(state['x_start'], x_measured):
            raise RuntimeError('SGD Optimizer expects the measured point to match the pivot!')
        if x_measured.shape != y_measured.shape:
            raise RuntimeError('SGD Optimizer expects y_measured to be the gradient, but the shapes did not match!')

        state['x_start'], state['x_grad'] = x_measured.squeeze(0), y_measured.squeeze(0)

    def __call__(self, model, state):
        '''Apply the last recorded gradient (if any) and return the pivot as a batch of one.'''
        if 'x_grad' in state:
            param = state['x_param']
            param.grad = state['x_grad']
            state['optim'].step()
        return state['x_param'].detach_()[None]


@dataclass
class SGDGPOptimizer(SGDOptimizer):
    '''Gradient descent on the GP posterior mean, measuring the ±pi/2 shifts along every axis.

    Each call steps the pivot along the GP posterior-mean gradient and proposes the two shifted points
    of every axis for measurement. :meth:`update` adds them to the GP and re-evaluates the pivot.
    '''
    lr: float = 5e-2
    momentum: float = 0.9

    _measure_goals = (MeasureGoal.LEFT, MeasureGoal.RIGHT)

    def __post_init__(self):
        # shifts (radians) applied to the pivot along each axis
        self.shifts = torch.tensor([-.5, .5]) * math.pi

    def init_state(self, model):
        '''Same as the SGD state, but the sampler is asked for energies rather than gradients.'''
        state = super().init_state(model)
        state['grad'] = False
        return state

    def update(self, model, state, x_measured, y_measured, y_var=None):
        '''Add the measurements to the GP, then refresh ``y_start`` (and ``y_best``/``x_best``).

        ``y_start`` is re-evaluated from the posterior mean at the pivot even if ``model.update`` raises.
        '''
        try:
            with self.update_logger(model, state, x_measured, y_measured, y_var):
                model.update(
                    x_measured,
                    y_measured,
                    y_var,
                    meta=[
                        MeasureMeta(step=state['step'], goal=goal, readout=state['n_readout'])
                        for goal in islice(cycle(state['goal']), len(x_measured))
                    ]
                )
        finally:
            state['y_start'] = model.posterior(state['x_start'][None], diag=True).mean
            if state['y_start'].item() < state['y_best'].item():
                state['y_best'] = state['y_start']
                state['x_best'] = state['x_start']

    @contextmanager
    def update_logger(self, model, state, x_measured, y_measured, y_var):
        '''Log prior to and after updating model with measured points.

        Only active when ``self.debug`` is truthy. Which entries are written to ``state['log']`` depends
        on the ``'var'``, ``'grad'`` and ``'meas'`` debug flags. ``y_var`` may be None.
        '''
        if not self.debug:
            yield
            return
        if self.debug.has('grad'):
            pivot = state['x_start'][None].clone().requires_grad_()

        def vargrad_items(key):
            if self.debug.has('grad'):
                grad = torch.autograd.grad(model.posterior(pivot, diag=True).mean.sum(), pivot)[0].squeeze(0)
                kdist = model.posterior_grad(state['x_start'][None], diag=True)
                kgrad = kdist.mean.squeeze(-1)
                kgradvar = kdist.var.squeeze(-1)
            return {
                # variance
                **({} if not self.debug.has('var') else {
                    f'pivot_{key}_var': model.posterior(state['x_start'][None], diag=True).var.item(),
                    f'gp_{key}_var': list(model.posterior(x_measured, diag=True).var),
                }),
                # gradients
                **({} if not self.debug.has('grad') else {
                    # autograd grad
                    f'pivot_{key}_grad_2': ((grad ** 2).mean() ** .5).item(),
                    f'pivot_{key}_grad_1': grad.abs().mean().item(),
                    # kernel-grad grad
                    f'pivot_{key}_kgrad_2': ((kgrad ** 2).mean() ** .5).item(),
                    f'pivot_{key}_kgrad_1': kgrad.abs().mean().item(),
                    f'pivot_{key}_kgradvar_2': (kgradvar.mean() ** .5).item(),
                }),
            }

        state['log'].update({
            **vargrad_items('prior'),
        })
        try:
            yield
        finally:
            state['log'].update({
                **({} if not self.debug.has('meas') else {
                    'x_meas': list(x_measured),
                    'y_meas': list(y_measured),
                    # y_var defaults to None in update(); list(None) would raise inside this finally
                    'y_var': None if y_var is None else list(y_var),
                }),
                **vargrad_items('post'),
            })

    def choose_pivot(self, model, state):
        '''Do a gradient descent step on ``x_param`` using the GP posterior-mean gradient.'''
        state['x_param'].grad = model.posterior_grad(state['x_param'].detach_()[None], diag=True).mean.squeeze(-1)
        state['optim'].step()
        return state['x_param'].detach_()

    def choose_measurements(self, model, state):
        '''Return the shifted points for every axis, concatenated along the first dimension.'''
        return torch.cat([
            make_shift(state['x_start'], self.shifts, axis)
            for axis in product(*(range(n) for n in state['x_start'].shape))
        ])

    def choose_readout(self, model, state):
        '''Choose the number of readouts to be used to measure the next shifts from the pivot along the axis.'''
        return state['n_readout']

    def finalize_measurements(self, model, state):
        '''Drop the measurements whose readout count is zero.

        Returns:
            ``(x_meas, n_readout, goals)`` restricted to entries with a positive readout count. When no
            entry remains, returns an empty slice of ``x_meas`` and two empty tuples.
        '''
        n_readout = state['n_readout']
        x_meas = state['x_meas']
        if isinstance(n_readout, int):
            n_readout = (n_readout,)
        if len(n_readout) == 1:
            # tuple() so a 1-element tensor is repeated rather than multiplied
            n_readout = tuple(n_readout) * len(x_meas)

        kept = [
            (index, readout, goal)
            for index, (readout, goal) in enumerate(zip(n_readout, state['goal'])) if readout > 0
        ]
        if not kept:
            return x_meas[:0], (), ()
        indices, n_readout, goals = zip(*kept)

        return x_meas[list(indices)], n_readout, goals

    def __call__(self, model, state):
        state['x_start'] = self.choose_pivot(model, state)

        state['x_meas'] = self.choose_measurements(model, state)
        state['n_readout'] = self.choose_readout(model, state)
        return state['x_meas']

@dataclass
class GRADCOREOptimizer(SGDOptimizer):
    '''Gradient descent on the GP posterior with adaptive measurement of the shifted points.

    Each call does three things:
    - moves the pivot one optimizer step along the GP posterior-mean gradient;
    - updates the threshold ``state['corethresh']`` (kappa) with ``_corethresh_<corethresh_strategy>``;
    - picks the points and readout counts to measure with ``_readout_<readout_strategy>``.
    '''
    _measure_goals = (MeasureGoal.LEFT, MeasureGoal.RIGHT)

    lr: float = 1e-2
    momentum: float = 0.9

    gridsize: int = 100
    corethresh: float = 1.0
    corethresh_width: int = 10
    corethresh_scale: float = 1.
    coremin_scale: float = 1e-5
    corethresh_shift: float = 0.0
    corethresh_power: float = 1.
    single_readout_var: float = None
    readout_strategy: str = 'max'
    corethresh_strategy: str = 'linreg'
    coreref: str = 'cur'
    coremetric: str = 'std'
    coremargin: float = 1.0
    coremomentum: int = 0
    pnorm: float = 2.

    def __post_init__(self):
        # state key whose value is tracked by the linreg corethresh strategy
        self._coreref = {
            'cur': 'y_start',
            'best': 'y_best'
        }[self.coreref]
        # shifts (radians) applied to the pivot along each axis
        self.shifts = torch.tensor([-.5, .5]) * math.pi

    def init_state(self, model):
        '''Extend the SGD state with core-threshold bookkeeping and an LR scheduler.

        Raises:
            ValueError: if ``coremetric`` is unknown or ``corethresh_shift`` is negative.
        '''
        state = super().init_state(model)
        state['single_readout_var'] = (
            model.y_var_default * state['n_readout']
            if self.single_readout_var is None else
            self.single_readout_var
        )
        if self.coremetric == 'std':
            # NOTE(review): 8192 looks like a reference readout count -- confirm
            state['min_corethresh'] = (state['single_readout_var'] / 8192) ** .5 * self.coremin_scale
            state['corethresh'] = self.corethresh
        elif self.coremetric == 'readout':
            state['min_corethresh'] = (state['single_readout_var'] / self.coremin_scale) ** .5
            state['corethresh'] = (state['single_readout_var'] / self.corethresh) ** .5
        else:
            # without this, 'corethresh'/'min_corethresh' would be missing and fail later with KeyError
            raise ValueError(f'No such core metric: {self.coremetric!r}')
        state['initial_readout'] = state['n_readout']
        state['corethresh_scale'] = self.corethresh_scale
        state['coremargin'] = self.coremargin
        if self.corethresh_shift < 0.0:
            raise ValueError(f'Corethresh shift value cannot be negative! Supplied value is: {self.corethresh_shift}.')
        state['corethresh_shift'] = self.corethresh_shift
        state['scheduler'] = ReduceLROnPlateau(
            state['optim'],
            factor=0.99,
            patience=10,
            threshold=1e-4,
            min_lr=1e-7)
        state['grad'] = False
        return state

    @staticmethod
    def _nonzero_readouts(n_readout, n_points):
        '''Return the non-zero readout counts as a 1-d tensor; a single count is broadcast to ``n_points``.'''
        readouts = torch.as_tensor(n_readout).reshape(-1)
        readouts = readouts[readouts != 0]
        if len(readouts) == 1 and n_points != 1:
            readouts = readouts.expand(n_points)
        return readouts

    def update(self, model, state, x_measured, y_measured, y_var=None):
        '''Add the measured shifts to the GP, then refresh the pivot mean/std and the best point.

        Zero entries of ``state['n_readout']`` mark points that were not measured; they are skipped, so
        readouts line up with ``x_measured``. If ``y_var`` is None it is derived as
        ``single_readout_var / n_readout``, or ``model.y_var_default`` when no single-readout
        variance is stored.
        '''
        try:
            readouts = self._nonzero_readouts(state['n_readout'], len(x_measured))
            if y_var is None:
                if 'single_readout_var' in state:
                    y_var = state['single_readout_var'] / readouts.to(x_measured.dtype)
                else:
                    y_var = torch.full_like(y_measured, model.y_var_default)

            with self.update_logger(model, state, x_measured, y_measured, y_var):
                model.update(
                    x_measured,
                    y_measured,
                    y_var=y_var,
                    meta=[
                        MeasureMeta(step=state['step'], goal=goal, readout=int(readout))
                        for goal, readout in zip(islice(cycle(state['goal']), len(x_measured)), readouts)
                    ]
                )

        finally:
            state['y_start'] = model.posterior(state['x_start'][None], diag=True).mean
            state['y_start_std'] = model.posterior(state['x_start'][None], diag=True).std
            if state['y_start'].item() < state['y_best'].item():
                state['y_best'] = state['y_start']
                state['x_best'] = state['x_start']

    @contextmanager
    def update_logger(self, model, state, x_measured, y_measured, y_var):
        '''Log prior to and after updating model with measured points.

        Only active when ``self.debug`` is truthy. Which entries are written to ``state['log']`` depends
        on the ``'var'``, ``'grad'`` and ``'meas'`` debug flags.
        '''
        if not self.debug:
            yield
            return
        if self.debug.has('grad'):
            pivot = state['x_start'][None].clone().requires_grad_()

        def vargrad_items(key):
            if self.debug.has('grad'):
                grad = torch.autograd.grad(model.posterior(pivot, diag=True).mean.sum(), pivot)[0].squeeze(0)
                kdist = model.posterior_grad(state['x_start'][None], diag=True)
                kgrad = kdist.mean.squeeze(-1)
                kgradvar = kdist.var.squeeze(-1)
            return {
                # variance
                **({} if not self.debug.has('var') else {
                    f'pivot_{key}_var': model.posterior(state['x_start'][None], diag=True).var.item(),
                    f'gp_{key}_var': list(model.posterior(x_measured, diag=True).var),
                }),
                # gradients
                **({} if not self.debug.has('grad') else {
                    # autograd grad
                    f'pivot_{key}_grad_2': ((grad ** 2).mean() ** .5).item(),
                    f'pivot_{key}_grad_1': grad.abs().mean().item(),
                    # kernel-grad grad
                    f'pivot_{key}_kgrad_2': ((kgrad ** 2).mean() ** .5).item(),
                    f'pivot_{key}_kgrad_1': kgrad.abs().mean().item(),
                    f'pivot_{key}_kgradvar_2': (kgradvar.mean() ** .5).item(),
                }),
            }

        state['log'].update({
            **vargrad_items('prior'),
        })
        try:
            yield
        finally:
            state['log'].update({
                **({} if not self.debug.has('meas') else {
                    'x_meas': list(x_measured),
                    'y_meas': list(y_measured),
                    'y_var': None if y_var is None else list(y_var),
                }),
                **vargrad_items('post'),
            })

    def choose_pivot(self, model, state):
        '''Do a gradient descent step on ``x_param`` and step the LR scheduler on the logged ``y_true``.'''
        state['x_param'].grad = model.posterior_grad(state['x_param'].detach_()[None], diag=True).mean.squeeze(-1)
        state['optim'].step()
        # the log may hold debug entries without 'y_true'; only step the scheduler when the value exists
        y_true = state['log'].get('y_true')
        if y_true is not None:
            state['scheduler'].step(round(float(y_true), 3))
        return state['x_param'].detach_()

    def choose_measurements(self, model, state):
        '''Return the shifted points for every axis, concatenated along the first dimension.'''
        return torch.cat([
            make_shift(state['x_start'], self.shifts, axis)
            for axis in product(*(range(n) for n in state['x_start'].shape))
        ])

    def _readout_max(self, model, state):
        '''Simple readout strategy based on kappa as upper bound.

        Returns a single readout count such that ``single_readout_var / count`` is about
        ``corethresh ** 2``. It is applied to every shift.
        '''
        return int(state['single_readout_var'] / state['corethresh'] ** 2)

    def _readout_core(self, model, state):
        '''Readout strategy based on predictive posterior variance.

        For each axis the strategy does the following:
        - If the posterior-gradient std at the pivot is already within ``corethresh * coremargin``,
          the axis gets zero readouts and is not measured.
        - Otherwise, it picks the first candidate variance (largest first, from ``single_readout_var``
          down to ``corethresh ** 2``) whose peeked posterior std meets the threshold. It falls back to
          ``corethresh ** 2`` when none does.

        Returns:
            ``(x_meas, n_readout)``. ``x_meas`` holds only the measured shifts; ``n_readout`` has one
            entry per shift of every axis, including zeros for skipped axes.
        '''
        x_test = state['x_param']
        x_meas, n_readout_all = [], []
        grad_pivot = model.posterior_grad(x_test[None], diag=True)
        for axis in product(*(range(n) for n in state['x_start'].shape)):
            x_cand_meas = make_shift(state['x_start'], self.shifts, axis)
            if (grad_pivot[axis].std <= state['corethresh'] * state['coremargin']):
                n_readout_all.extend(torch.zeros([2,1]))
                continue
            x_meas.append(x_cand_meas)
            y_cand_var = logspace(state['single_readout_var'], state['corethresh'] ** 2, 100, 200)
            peek_distr = model.peek_posterior_grad(x_cand_meas, x_test[None], y_var=y_cand_var[:-1, None],diag=True)

            # assumes the satisfied-mask is monotone (0s then 1s) along the candidate list
            index = torch.searchsorted(
                (peek_distr.std[axis] <= state['corethresh'] * state['coremargin']).all(axis=-1).int(), True
            )
            y_var_all = y_cand_var[index].item()
            d_readout = int(state['single_readout_var'] / y_var_all)
            n_readout_all.append(torch.tensor([d_readout, d_readout]))

        readouts = torch.cat(n_readout_all, dim=0)
        if len(x_meas) == 0:
            # nothing to measure: placeholder points (one per readout entry) shaped like the pivot
            return state['x_start'].new_zeros((len(readouts), *state['x_start'].shape)), readouts
        return torch.cat(x_meas, dim=0), readouts

    def _corethresh_linreg(self, model, state):
        '''Corethresh (kappa) strategy based on linear regression of the corewidth-th last energy values.

        Once more than ``corethresh_width`` values are logged, kappa is set to
        ``max(min_corethresh, |slope| ** corethresh_power * corethresh_scale)``. The slope is fitted
        over the last ``corethresh_width`` values.
        '''
        if self.corethresh_width > 0:
            state.setdefault('energy_log', []).append(state[self._coreref].item())
            if len(state['energy_log']) > self.corethresh_width:
                if 'corethresh_pinv' not in state:
                    state['corethresh_pinv'] = torch.linalg.pinv(torch.stack([
                        torch.arange(float(self.corethresh_width)),
                        torch.ones(self.corethresh_width),
                    ], axis=-1))
                slope = (
                    state['corethresh_pinv']
                    @ torch.tensor(state['energy_log'][-self.corethresh_width:])
                )[0].abs().item()

                state['corethresh'] = max(
                    state['min_corethresh'],
                    slope ** self.corethresh_power * state['corethresh_scale']
                )

    def _corethresh_grad(self, model, state, cheat=False):
        '''Corethresh (kappa) strategy based on the gradient norm at the current pivot point.

        The mean ``|grad| ** pnorm`` is averaged over the last ``corethresh_width`` steps. The result
        is taken to the power ``1 / pnorm``, shifted, scaled and bounded below by ``min_corethresh``.
        With ``cheat`` the gradient of ``state['true_fn']`` is used instead of the GP posterior.
        '''
        pivot = state['x_start'][None].clone().requires_grad_()
        if cheat:
            y_true = state['true_fn'](pivot, n_readout=0)
            grad = torch.mean(
                torch.autograd.grad(y_true.sum(), pivot)[0].abs() ** self.pnorm
            ).item()
        else:
            grad = torch.mean(
                model.posterior_grad(pivot[None], diag=True).mean[:,:,:,0].abs() ** self.pnorm
            ).item()
        buf = state.setdefault('grad_log', deque(maxlen=self.corethresh_width))
        buf.append(grad)
        if len(buf) >= self.corethresh_width:
            slope = np.mean(buf) ** (1. / self.pnorm)
            state['corethresh'] = max(
                state['min_corethresh'],
                (slope - state['corethresh_shift']) ** self.corethresh_power * state['corethresh_scale']
            )

    def choose_measurement_readout(self, model, state):
        '''Update kappa, then choose the points to measure and their readout counts.

        Readout strategies may return ``(x_meas, n_readout)`` or only a readout count (like
        ``_readout_max``). In the latter case every shift from :meth:`choose_measurements` is measured
        with that count. An int count is expanded to one entry per point.

        Raises:
            RuntimeError: if the configured corethresh or readout strategy does not exist.
        '''
        corethresh_method_name = f'_corethresh_{self.corethresh_strategy}'
        if not hasattr(self, corethresh_method_name):
            raise RuntimeError(f'No such corethresh strategy: \'{self.corethresh_strategy}\'')

        getattr(self, corethresh_method_name)(model, state)

        readout_method_name = f'_readout_{self.readout_strategy}'
        if not hasattr(self, readout_method_name):
            raise RuntimeError(f'No such readout strategy: \'{self.readout_strategy}\'')

        result = getattr(self, readout_method_name)(model, state)
        if isinstance(result, tuple):
            x_meas, n_readout = result
        else:
            x_meas, n_readout = self.choose_measurements(model, state), result
        if isinstance(n_readout, int):
            # a per-point tensor is what step()/update() expect (masked_select needs a tensor)
            n_readout = torch.full((len(x_meas),), n_readout)
        return x_meas, n_readout

    def finalize_measurements(self, model, state):
        '''Finalize the measurement of x_shifts for all necessary axes, dropping zero-readout entries.

        Returns:
            ``(x_meas, n_readout, goals)`` restricted to entries with a positive readout count. When no
            entry remains, returns an empty slice of ``x_meas`` and two empty tuples.
        '''
        n_readout = state['n_readout']
        x_meas = state['x_meas']
        if isinstance(n_readout, int):
            n_readout = (n_readout,)
        if len(n_readout) == 1:
            # tuple() so a 1-element tensor is repeated rather than multiplied
            n_readout = tuple(n_readout) * len(x_meas)

        kept = [
            (index, readout, goal)
            for index, (readout, goal) in enumerate(zip(n_readout, state['goal'])) if readout > 0
        ]
        if not kept:
            return x_meas[:0], (), ()
        indices, n_readout, goals = zip(*kept)

        return x_meas[list(indices)], n_readout, goals

    def __call__(self, model, state):
        state['x_start'] = self.choose_pivot(model, state)
        state['x_meas'], state['n_readout'] = self.choose_measurement_readout(model, state)
        return state['x_meas']

def make_shift(pivot, shifts, axis, zero_align=False):
    """Create copies of ``pivot`` with each value of ``shifts`` added at the index ``axis``.

    The result has shape ``(*shifts.shape, *pivot.shape)`` and is wrapped into ``[0, 2*pi)``. With
    ``zero_align`` the pivot's entry at ``axis`` is set to 0 first, so the shifts become absolute values.
    """
    base = pivot
    if zero_align:
        base = base.clone()
        base[axis] = 0.
    out_shape = tuple(shifts.shape) + tuple(pivot.shape)
    shift_ranges = [torch.arange(size) for size in shifts.shape]
    fixed_axis = [torch.tensor((idx,)) for idx in axis]
    indices = torch.cartesian_prod(*shift_ranges, *fixed_axis).t()
    offsets = torch.sparse_coo_tensor(indices, shifts.flatten(), out_shape)
    broadcast = base[(None,) * shifts.ndim].expand(out_shape)
    return (broadcast + offsets) % math.tau


class LeastSquaresWave:
    """Fit ``c0 + c1 * cos(x) + c2 * sin(x)`` to a pivot value and the values at shifted offsets.

    The samples are the pivot (offset 0) followed by ``shifts``. The pseudo-inverse of the design matrix
    is computed lazily and cached until ``shifts`` is reassigned.
    """

    def __init__(self, shifts):
        self._shifts = shifts
        self._pinv = None

    @property
    def pinv(self):
        """Cached pseudo-inverse of the design matrix for the current shifts."""
        if self._pinv is None:
            self._pinv = self.fit(self._shifts)
        return self._pinv

    @property
    def shifts(self):
        return self._shifts

    @shifts.setter
    def shifts(self, value):
        # invalidate the cached pseudo-inverse; it is rebuilt on next access
        self._pinv = None
        self._shifts = value

    @staticmethod
    def fit(shifts):
        """Return the pseudo-inverse of the ``[1, cos, sin]`` design matrix at offsets ``[0, *shifts]``."""
        offsets = torch.tensor([0., *shifts], dtype=torch.float64)
        design = torch.stack((torch.ones_like(offsets), torch.cos(offsets), torch.sin(offsets)), dim=1)
        return torch.linalg.pinv(design)

    def _coefficients(self, y_pivot, y_shifts):
        """Return the fitted ``(c0, c1, c2)`` for the pivot value followed by the shifted values."""
        targets = torch.tensor([y_pivot, *y_shifts], dtype=torch.float64)
        return self.pinv @ targets

    def solve(self, pivot, y_pivot, y_shifts, axis):
        """Move ``pivot`` along ``axis`` to the minimum of the fitted wave.

        Returns:
            ``(new_pivot, predicted_minimum)``; ``new_pivot`` is wrapped into ``[0, 2*pi)``.
        """
        c0, c1, c2 = self._coefficients(y_pivot, y_shifts)
        # atan2(c2, c1) + pi is the argmin of c1 * cos x + c2 * sin x
        theta = torch.atan2(c2, c1) + math.pi
        minimum = c0 + c1 * theta.cos() + c2 * theta.sin()
        step = torch.sparse_coo_tensor(list(zip(axis)), theta, pivot.shape)
        return (pivot + step) % math.tau, minimum

    def __call__(self, y_pivot, y_shifts, x_cand):
        """Evaluate the fitted wave at ``x_cand``."""
        c0, c1, c2 = self._coefficients(y_pivot, y_shifts)
        return c0 + c1 * x_cand.cos() + c2 * x_cand.sin()


@dataclass
class LineSearchOptimizer(Optimizer):
    '''Template for optimizers that search along one axis per step.

    Subclasses provide ``choose_axis``, ``choose_measurements``, ``choose_pivot`` and
    ``require_stabilize``. :meth:`__call__` combines them into the list of points to measure.
    '''
    def update(self, model, state, x_measured, y_measured, y_var=None):
        '''Add the measurements to the GP and refresh ``y_start`` from the posterior mean at the pivot.

        If ``y_var`` is None it is derived as ``single_readout_var / n_readout``, or
        ``model.y_var_default`` when no single-readout variance is stored. ``y_start`` is refreshed
        even if ``model.update`` raises.
        '''
        try:
            if y_var is None:
                if 'single_readout_var' in state:
                    n_readout = state['n_readout']
                    if isinstance(n_readout, int):
                        n_readout = (n_readout,)
                    # as_tensor: n_readout is usually already a tensor (see finalize_measurements)
                    y_var = state['single_readout_var'] / torch.as_tensor(n_readout, dtype=x_measured.dtype)
                else:
                    y_var = torch.full_like(y_measured, model.y_var_default)

            with self.update_logger(model, state, x_measured, y_measured, y_var):
                model.update(
                    x_measured,
                    y_measured,
                    y_var=y_var,
                    meta=[
                        MeasureMeta(step=state['step'], goal=goal, readout=int(readout))
                        for goal, readout in zip(state['goal'], state['n_readout'])
                    ]
                )

        finally:
            state['y_start'] = model.posterior(state['x_start'][None], diag=True).mean

    @contextmanager
    def update_logger(self, model, state, x_measured, y_measured, y_var):
        '''Log prior to and after updating model with measured points.

        Only active when ``self.debug`` is truthy. Which entries are written to ``state['log']`` depends
        on the ``'var'``, ``'grad'`` and ``'meas'`` debug flags. Directional entries use ``state['k_best']``.
        '''
        if not self.debug:
            yield
            return
        if self.debug.has('grad'):
            pivot = state['x_start'][None].clone().requires_grad_()

        def vargrad_items(key):
            if self.debug.has('grad'):
                grad = torch.autograd.grad(model.posterior(pivot, diag=True).mean.sum(), pivot)[0].squeeze(0)
                kdist = model.posterior_grad(state['x_start'][None], diag=True)
                kgrad = kdist.mean.squeeze(-1)
                kgradvar = kdist.var.squeeze(-1)
            return {
                # variance
                **({} if not self.debug.has('var') else {
                    f'pivot_{key}_var': model.posterior(state['x_start'][None], diag=True).var.item(),
                    f'gp_{key}_var': list(model.posterior(x_measured, diag=True).var),
                    f'gp_{key}_var_max': model.posterior(
                        make_shift(state['x_start'], torch.linspace(0, math.tau, 100), state['k_best']),
                        diag=True,
                    ).var.amax(),
                }),
                # gradients
                **({} if not self.debug.has('grad') else {
                    # autograd grad
                    f'pivot_{key}_grad_2': ((grad ** 2).mean() ** .5).item(),
                    f'pivot_{key}_grad_1': grad.abs().mean().item(),
                    f'pivot_{key}_grad_dir': (grad[state['k_best']]).item(),
                    f'pivot_{key}_grad_ldir': (grad[shape_step(state['k_best'], grad.shape, -1)]).item(),
                    # kernel-grad grad
                    f'pivot_{key}_kgrad_2': ((kgrad ** 2).mean() ** .5).item(),
                    f'pivot_{key}_kgrad_1': kgrad.abs().mean().item(),
                    f'pivot_{key}_kgrad_dir': (kgrad[state['k_best']]).item(),
                    f'pivot_{key}_kgrad_ldir': (kgrad[shape_step(state['k_best'], kgrad.shape, -1)]).item(),
                    f'pivot_{key}_kgradvar_2': (kgradvar.mean() ** .5).item(),
                }),
            }

        state['log'].update({
            **vargrad_items('prior'),
        })
        try:
            yield
        finally:
            state['log'].update({
                **({} if not self.debug.has('meas') else {
                    'x_meas': list(x_measured),
                    'y_meas': list(y_measured),
                    'y_var': None if y_var is None else list(y_var),
                }),
                **vargrad_items('post'),
            })

    def choose_axis(self, model, state):
        '''Choose the next axis, setting ``state['k_best']``.'''
        raise NotImplementedError()

    def choose_measurements(self, model, state):
        '''Choose the next shifts from the pivot along the axis.'''
        raise NotImplementedError()

    def choose_readout(self, model, state):
        '''Choose the number of readouts to be used to measure the next shifts from the pivot along the axis.'''
        return self.n_readout

    def choose_pivot(self, model, state):
        '''Given the shift candidates, find the best point on the line.'''
        raise NotImplementedError()

    def require_stabilize(self, model, state):
        '''Given the pivot point, decide whether the current best point should be measured for stabilization.'''
        raise NotImplementedError()

    def finalize_measurements(self, model, state):
        '''Drop the measurements whose readout count is falsy (zero).

        Returns:
            ``(x_meas, n_readout, goals)``. ``n_readout`` is a tensor; all three are empty when no
            measurement remains.
        '''
        n_readout = state['n_readout']
        x_meas = state['x_meas']
        if isinstance(n_readout, int):
            n_readout = (n_readout,)
        if len(n_readout) == 1:
            # tuple() so a 1-element tensor is repeated rather than multiplied
            n_readout = tuple(n_readout) * len(x_meas)

        measurements = list(zip(*(
            (index, readout, goal)
            for index, (readout, goal) in enumerate(zip(n_readout, state['goal'])) if readout
        )))
        if measurements:
            indices, n_readout, goals = measurements
        else:
            indices, n_readout, goals = [], [], []

        return x_meas[list(indices)], torch.tensor(n_readout), goals

    def __call__(self, model, state):
        '''Move to the line minimum of the previous axis, pick the next axis, and return the points to measure.'''
        state['goal'] = self._measure_goals
        if 'k_best' in state:
            # update x_start using previous axis, x_start and x_shift
            state['x_pivot'], state['y_pivot'] = self.choose_pivot(model, state)

            state['x_start'], state['y_start'] = state['x_pivot'], state['y_pivot']

            if state['y_start'].item() < state['y_best'].item():
                state['y_best'] = state['y_start']
        state['k_best'] = self.choose_axis(model, state)
        state['x_meas'] = self.choose_measurements(model, state)

        if self.require_stabilize(model, state):
            state['x_meas'] = torch.cat((state['x_meas'], state['x_pivot'][None]))
            state['goal'] = (*self._measure_goals, MeasureGoal.PIVOT)

        state['n_readout'] = self.choose_readout(model, state)
        state['x_meas'], state['n_readout'], state['goal'] = self.finalize_measurements(model, state)
        return state['x_meas']


@dataclass
class SMOOptimizer(LineSearchOptimizer):
    '''Coordinate-wise line search using sinusoid fits of the GP posterior mean.

    The axes are visited in row-major order. Along each axis the two points ``pivot ± shift`` are
    measured, and the pivot moves to the minimum of the wave fitted to the posterior mean.
    '''
    _shift_modes = {
        '2pi3': math.tau / 3.,
        'pi3': math.pi / 3.,
        'pi2': math.pi / 2.,
        '5pi8': 5 * math.pi / 8.,
        '3pi8': 3 * math.pi / 8.,
        'pi6': math.pi / 6.,
    }
    _measure_goals = (MeasureGoal.LEFT, MeasureGoal.RIGHT)

    stabilize_interval: int = 0
    shift_mode: str = '2pi3'

    def __post_init__(self):
        offset = self._shift_modes[self.shift_mode]
        self.lsw = LeastSquaresWave((-offset, offset))

    def choose_axis(self, model, state):
        '''Return the first axis on the first call, afterwards the axis following ``state['k_best']``.'''
        if 'k_best' in state:
            return shape_step(state['k_best'], state['x_start'].shape)
        return (0,) * state['x_start'].ndim

    def choose_measurements(self, model, state):
        '''Return the two shifted points of the pivot along ``state['k_best']``.'''
        shifts = torch.tensor(self.lsw.shifts)
        return make_shift(state['x_start'], shifts, state['k_best'])

    def choose_pivot(self, model, state):
        '''Fit the wave to the posterior mean at the pivot and its shifts; return its minimum point and value.'''
        pivot, axis = state['x_start'], state['k_best']
        shifted = make_shift(pivot, torch.tensor(self.lsw.shifts), axis)
        y_pivot = model.posterior(pivot[None], diag=True).mean
        y_shifted = model.posterior(shifted, diag=True).mean
        return self.lsw.solve(pivot, y_pivot, y_shifted, axis)

    def require_stabilize(self, model, state):
        '''Stabilize on every ``stabilize_interval``-th step (never at step 0 or when the interval is 0).'''
        step, interval = state['step'], self.stabilize_interval
        return step and interval and not step % interval


@dataclass
class NFTOptimizer(SMOOptimizer):
    '''Sequentially finds the line-wise minimum according to NFT. Does not use or update the GP.

    The line is solved from the stored measurements ``state['y_start']`` and ``state['y_pairs']`` rather than
    from GP posterior means.
    '''
    def update(self, model, state, x_measured, y_measured, y_var=None):
        '''Store the latest measurements in ``state`` for the next line solve.

        Accepted batch sizes:

        - 1 sample: a stabilization re-measurement of the pivot; it must match ``state['x_start']``.
        - 2 samples: the two shifted points; they must match ``state['x_meas']``.
        - 3 samples: the two shifted points followed by the new pivot; they must match ``state['x_meas']``.

        ``model`` and ``y_var`` are accepted for interface compatibility and are not used.

        Raises:
            RuntimeError: if the number of samples is unsupported or the points do not match the expected ones.
        '''
        n_measured = len(x_measured)
        if n_measured == 1:
            if not torch.allclose(state['x_start'], x_measured):
                raise RuntimeError(
                    'Measured exactly one sample, expected stabilization, but point did not match pivot!'
                )
            state['y_start'] = y_measured
        elif n_measured == 2:
            if not torch.allclose(state['x_meas'], x_measured):
                raise RuntimeError('Measured exactly two samples, expected shift, but points did not match!')
            state['x_pairs'] = x_measured
            state['y_pairs'] = y_measured
            state['x_measured'] = x_measured
            state['y_measured'] = y_measured
        elif n_measured == 3:
            if not torch.allclose(state['x_meas'], x_measured):
                raise RuntimeError(
                    'Measured exactly three samples, expected shift and pivot, but points did not match!'
                )
            # Layout is (left shift, right shift, pivot): the pivot was appended last to ``x_meas``.
            state['x_pairs'] = x_measured[:2]
            state['y_pairs'] = y_measured[:2]
            state['x_start'] = x_measured[2]
            state['y_start'] = y_measured[2]
        else:
            raise RuntimeError(f'Expected 1, 2 or 3 measured samples, got {n_measured}!')

    def choose_pivot(self, model, state):
        '''Given the shift candidates, find the best point on the line from the stored measurements.'''
        return self.lsw.solve(state['x_start'], state['y_start'], state['y_pairs'], state['k_best'])


@dataclass
class SUBSCOREOptimizer(SMOOptimizer):
    '''SMO variant that measures left shift, pivot and right shift with an adaptive number of readouts.

    On every step ``choose_readout`` first updates the threshold ``state['corethresh']`` (kappa) with the method
    ``_corethresh_<corethresh_strategy>`` and then derives the number of readouts with
    ``_readout_<readout_strategy>``.
    '''
    _measure_goals = (MeasureGoal.LEFT, MeasureGoal.PIVOT, MeasureGoal.RIGHT)

    gridsize: int = 100  # number of test points on [0, tau) used by the 'core'/'center' readout strategies
    corethresh: float = 1.0  # initial kappa, interpreted according to ``coremetric``
    corethresh_width: int = 10  # window length used by the kappa strategies
    corethresh_scale: float = 1.
    coremin_scale: float = 1e-5
    corethresh_shift: float = 0.0  # subtracted from the gradient estimate in 'grad'/'cheatgrad'; must be >= 0
    corethresh_power: float = 1.
    single_readout_var: float = None  # None: use ``model.y_var_default * state['n_readout']``
    readout_strategy: str = 'max'  # selects ``_readout_<name>``: 'max', 'core' or 'center'
    corethresh_strategy: str = 'linreg'  # selects ``_corethresh_<name>``
    coreref: str = 'cur'  # 'cur' -> state['y_start'], 'best' -> state['y_best']
    coremetric: str = 'std'  # 'std' or 'readout'; how ``corethresh``/``coremin_scale`` are interpreted
    coremargin: float = 1.0  # multiplies kappa in the 'core'/'center' readout check
    coremomentum: int = 0  # if > 0, kappa is averaged over this many most recent values
    pnorm: float = 2.  # exponent of the gradient norm in 'grad'/'cheatgrad'

    def __post_init__(self):
        super().__post_init__()
        # Insert a zero shift so that the pivot itself is measured between the left and right shifts.
        self.lsw.shifts = (self.lsw.shifts[0], 0., self.lsw.shifts[1])
        self._coreref = {
            'cur': 'y_start',
            'best': 'y_best'
        }[self.coreref]

    def init_state(self, model):
        '''Extend the parent state with the readout-variance and kappa bookkeeping used by this optimizer.

        Raises:
            ValueError: if ``coremetric`` is neither ``'std'`` nor ``'readout'``, or ``corethresh_shift`` is negative.
        '''
        state = super().init_state(model)
        state['single_readout_var'] = (
            model.y_var_default * state['n_readout']
            if self.single_readout_var is None else
            self.single_readout_var
        )
        if self.coremetric == 'std':
            state['min_corethresh'] = (state['single_readout_var'] / 8192) ** .5 * self.coremin_scale
            state['corethresh'] = self.corethresh
        elif self.coremetric == 'readout':
            # ``corethresh``/``coremin_scale`` are readout counts here; convert them to standard deviations.
            state['min_corethresh'] = (state['single_readout_var'] / self.coremin_scale) ** .5
            state['corethresh'] = (state['single_readout_var'] / self.corethresh) ** .5
        else:
            # Without this, 'corethresh'/'min_corethresh' would stay unset and every readout strategy would fail later.
            raise ValueError(f'No such coremetric: \'{self.coremetric}\'')
        state['initial_readout'] = state['n_readout']
        state['corethresh_scale'] = self.corethresh_scale
        state['coremargin'] = self.coremargin
        if self.corethresh_shift < 0.0:
            raise ValueError(f'Corethresh shift value cannot be negative! Supplied value is: {self.corethresh_shift}.')
        state['corethresh_shift'] = self.corethresh_shift
        return state

    def _readout_max(self, model, state):
        '''Simple readout strategy based on kappa as upper bound.

        Returns the (truncated) number of readouts whose averaged variance ``single_readout_var / n`` is about
        ``corethresh ** 2``.
        '''
        return int(state['single_readout_var'] / state['corethresh'] ** 2)

    def _readout_core(self, model, state, retry_center=False):
        '''Readout strategy based on predictive posterior variance.

        Scans candidate observation variances from ``single_readout_var`` down to ``corethresh ** 2``. It picks the
        first one for which the peeked posterior std at every grid point on the line is within
        ``corethresh * coremargin``.

        Returns:
            An int readout count when ``retry_center`` is False. Otherwise a tensor with one readout count per shift,
            where the center (pivot) variance is searched again with the side shifts fixed to the first result.
        '''
        grid_candidates = torch.linspace(0, math.tau, self.gridsize + 1)[:-1]
        x_tests = make_shift(
            state['x_start'],
            grid_candidates,
            state['k_best']
        )

        # Decreasing variances: y_var[0] == single_readout_var, y_var[-1] == corethresh ** 2.
        y_var = logspace(state['single_readout_var'], state['corethresh'] ** 2, 100, 200)
        x_meas = state['x_meas']

        # The last candidate is left out of the scan so it can serve as the fallback result below.
        peek_distr = model.peek_posterior(x_meas[None], x_tests, y_var=y_var[:-1, None], diag=True)

        # First index whose peeked std satisfies the threshold at all test points (assumes this indicator is
        # monotone); evaluates to len(y_var) - 1 when none does.
        index = torch.searchsorted(
            (peek_distr.std <= state['corethresh'] * state['coremargin']).all(axis=-1).int(), True
        )
        y_var_all = y_var[index].item()
        if not retry_center:
            return int(state['single_readout_var'] / y_var_all)

        if self.lsw.shifts[1] != 0.:
            raise RuntimeError('Shift 1 is not in the center!')

        # Side shifts keep the variance found above; only the center variance is scanned again.
        y_var_shifts = torch.full((y_var.shape[0], len(self.lsw.shifts)), y_var_all, dtype=y_var.dtype)
        y_var_shifts[:, 1] = y_var

        peek_distr_shifts = model.peek_posterior(x_meas[None], x_tests, y_var=y_var_shifts[:-1], diag=True)
        index_shifts = torch.searchsorted(
            (peek_distr_shifts.std <= state['corethresh'] * state['coremargin']).all(axis=-1).int(), True
        )

        n_readout = tuple(
            int((state['single_readout_var'] / y_v).item())
            for y_v in y_var_shifts[index_shifts]
        )
        return torch.tensor(n_readout)

    def _readout_center(self, model, state):
        '''Per-shift readout strategy: ``_readout_core`` with the center variance searched separately.'''
        return self._readout_core(model, state, retry_center=True)

    def _corethresh_last(self, model, state):
        '''Corethresh (kappa) strategy based on the corewidth-th last energy (LEGACY).

        The slope is the (signed) energy decrease per step between the ``corethresh_width``-th last and the latest
        logged reference energy.
        '''
        if self.corethresh_width > 0:
            state.setdefault('energy_log', []).append(state[self._coreref].item())
            if len(state['energy_log']) > self.corethresh_width:
                slope = (
                    state['energy_log'][-self.corethresh_width - 1] - state['energy_log'][-1]
                ) / self.corethresh_width
                state['corethresh'] = max(
                    state['min_corethresh'],
                    slope ** self.corethresh_power * state['corethresh_scale']
                )

    def _corethresh_lastabs(self, model, state):
        '''Corethresh (kappa) strategy based on the corewidth-th last energy.

        NOTE(review): intended to add an absolute upper readout bound, but no such bound is implemented; this is
        currently identical to ``_corethresh_last``.
        '''
        self._corethresh_last(model, state)

    def _corethresh_linreg(self, model, state):
        '''Corethresh (kappa) strategy based on linear regression of the corewidth-th last energy values.

        The slope is the absolute fitted slope of a least-squares line through the last ``corethresh_width``
        logged reference energies.
        '''
        if self.corethresh_width > 0:
            state.setdefault('energy_log', []).append(state[self._coreref].item())
            if len(state['energy_log']) > self.corethresh_width:
                if 'corethresh_pinv' not in state:
                    # Pseudo-inverse of the [t, 1] design matrix, cached because the window length is fixed.
                    state['corethresh_pinv'] = torch.linalg.pinv(torch.stack([
                        torch.arange(float(self.corethresh_width)),
                        torch.ones(self.corethresh_width),
                    ], axis=-1))
                slope = (
                    state['corethresh_pinv']
                    @ torch.tensor(state['energy_log'][-self.corethresh_width:])
                )[0].abs().item()

                state['corethresh'] = max(
                    state['min_corethresh'],
                    slope ** self.corethresh_power * state['corethresh_scale']
                )

    def _corethresh_avg(self, model, state):
        '''Corethresh (kappa) strategy based on the mean step of the corewidth-th last energy values.

        The slope is the absolute mean of consecutive differences over the last ``corethresh_width`` logged
        reference energies.
        '''
        if self.corethresh_width > 0:
            state.setdefault('energy_log', []).append(state[self._coreref].item())
            if len(state['energy_log']) > self.corethresh_width:
                energies = torch.tensor(state['energy_log'][-self.corethresh_width:])
                # ``.item()`` keeps ``state['corethresh']`` a Python float, like the other strategies.
                slope = (energies[1:] - energies[:-1]).mean().abs().item()
                state['corethresh'] = max(
                    state['min_corethresh'],
                    slope ** self.corethresh_power * state['corethresh_scale']
                )

    def _corethresh_grad(self, model, state, cheat=False):
        '''Corethresh (kappa) strategy based on the gradient norm at the current pivot point.

        Once ``corethresh_width`` gradient values have been logged, kappa is derived from their mean
        ``pnorm``-norm estimate minus ``corethresh_shift``. With ``cheat=True``, the gradient of
        ``state['true_fn']`` is used instead of the model's posterior gradient.
        '''
        pivot = state['x_start'][None].clone().requires_grad_()
        if cheat:
            y_true = state['true_fn'](pivot, n_readout=0)
            grad = torch.mean(
                torch.autograd.grad(y_true.sum(), pivot)[0].abs() ** self.pnorm
            ).item()
        else:
            grad = torch.mean(
                model.posterior_grad(pivot[None], diag=True).mean[:, :, :, 0].abs() ** self.pnorm
            ).item()
        buf = state.setdefault('grad_log', deque(maxlen=self.corethresh_width))
        buf.append(grad)
        if len(buf) >= self.corethresh_width:
            slope = np.mean(buf) ** (1. / self.pnorm)
            state['corethresh'] = max(
                state['min_corethresh'],
                (slope - state['corethresh_shift']) ** self.corethresh_power * state['corethresh_scale']
            )

    def _corethresh_cheatgrad(self, model, state):
        '''Corethresh (kappa) strategy based on the cheated true gradient norm at the current pivot point.'''
        if 'true_fn' not in state:
            raise RuntimeError('Tried cheating without explicitly allowing!')
        self._corethresh_grad(model, state, cheat=True)

    def choose_readout(self, model, state):
        '''Choose the number of readouts to be used to measure the next shifts from the pivot along the axis.

        Returns:
            A tuple with a single int, or whatever sequence the readout strategy returns (a tensor for 'center').

        Raises:
            RuntimeError: if ``corethresh_strategy`` or ``readout_strategy`` names no method of this class.
        '''
        corethresh_method_name = f'_corethresh_{self.corethresh_strategy}'
        if not hasattr(self, corethresh_method_name):
            raise RuntimeError(f'No such corethresh strategy: \'{self.corethresh_strategy}\'')

        getattr(self, corethresh_method_name)(model, state)

        if self.coremomentum:
            # Seed the history with the first kappa so the moving average is defined from the first step.
            if 'corethresh_log' not in state:
                state['corethresh_log'] = [state['corethresh']] * (self.coremomentum - 1)
            state['corethresh_log'].append(state['corethresh'])
            state['corethresh'] = sum(state['corethresh_log'][-self.coremomentum:]) / self.coremomentum

        readout_method_name = f'_readout_{self.readout_strategy}'
        if not hasattr(self, readout_method_name):
            raise RuntimeError(f'No such readout strategy: \'{self.readout_strategy}\'')

        # NOTE: the readout counts are not clamped to a minimum of 1 here.
        n_readout = getattr(self, readout_method_name)(model, state)
        if isinstance(n_readout, int):
            n_readout = (n_readout,)
        return n_readout


@dataclass
class EMICOREOptimizer(SMOOptimizer):
    '''SMO variant that picks shift pairs (and optionally axes) by an expected maximum improvement over a "core".

    For each candidate pair of shifts, the core is the set of line test points whose peeked posterior std is
    below kappa (``state['corethresh']``). The improvement is estimated by Monte Carlo as
    ``E[max(0, f(x_start) - min_{core} f)]``.
    '''
    pairsize: int = 20  # grid size per shift; pairsize ** 2 candidate pairs are scored
    gridsize: int = 100  # number of test points on [0, tau) along the line
    samplesize: int = 100  # Monte Carlo samples per candidate pair
    corethresh: float = 1.0  # initial kappa, interpreted according to ``coremetric``
    corethresh_width: int = 10  # window length used by the kappa strategies
    corethresh_scale: float = 1.
    corethresh_power: float = 1.
    corethresh_shift: float = 0.  # subtracted from the gradient estimate in 'grad'/'cheatgrad'; must be >= 0
    coremin_scale: float = 1e-5
    coremargin: float = 1.0
    core_trials: int = 10  # max consecutive stabilizations triggered by a high pivot std
    single_readout_var: float = None  # None: use ``model.y_var_default * state['n_readout']``
    corethresh_strategy: str = 'linreg'  # selects ``_corethresh_<name>``: 'linreg', 'grad' or 'cheatgrad'
    coreref: str = 'cur'  # 'cur' -> state['y_start'], 'best' -> state['y_best']
    coremetric: str = 'std'  # 'std' or 'readout'; how ``corethresh``/``coremin_scale`` are interpreted
    coremomentum: int = 0  # if > 0, kappa is averaged over this many most recent values
    smo_steps: int = 100  # steps before ``smo_steps`` use plain SMO axis and shift selection
    smo_axis: int = 0  # if truthy, always use SMO's cyclic axis selection
    pivot_steps: int = 0  # extra pivot refinement iterations in ``choose_pivot``
    pivot_scale: float = 1.0
    pivot_mode: str = 'smo'  # 'smo' or 'loop'
    pnorm: float = 2.  # exponent of the gradient norm in 'grad'/'cheatgrad'

    def __post_init__(self):
        super().__post_init__()
        self._coreref = {
            'cur': 'y_start',
            'best': 'y_best'
        }[self.coreref]

    def init_state(self, model):
        '''Extend the parent state with the readout-variance and kappa bookkeeping used by this optimizer.

        Raises:
            ValueError: if ``corethresh_shift`` is negative.
        '''
        state = super().init_state(model)
        state['single_readout_var'] = (
            model.y_var_default * state['n_readout']
            if self.single_readout_var is None else
            self.single_readout_var
        )
        if self.coremetric == 'std':
            state['min_corethresh'] = (state['single_readout_var'] / 8192) ** .5 * self.coremin_scale
            state['corethresh'] = self.corethresh
        elif self.coremetric == 'readout':
            # ``corethresh``/``coremin_scale`` are readout counts here; convert them to standard deviations.
            state['min_corethresh'] = (state['single_readout_var'] / self.coremin_scale) ** .5
            state['corethresh'] = (state['single_readout_var'] / self.corethresh) ** .5
        state['initial_readout'] = state['n_readout']
        state['corethresh_scale'] = self.corethresh_scale
        state['coremargin'] = self.coremargin
        if self.corethresh_shift < 0.0:
            raise ValueError(f'Corethresh shift value cannot be negative! Supplied value is: {self.corethresh_shift}.')
        state['corethresh_shift'] = self.corethresh_shift
        return state

    def _emicore(self, model, state, axis):
        '''Score candidate shift pairs along ``axis`` by their estimated expected maximum improvement over the core.

        Returns:
            ``(x_pairs, maximp)``. ``x_pairs`` holds the shifted points; row 0 is for the default shifts
            ``self.lsw.shifts`` and the remaining ``pairsize ** 2`` rows are all grid pairs. ``maximp`` holds one
            Monte Carlo improvement estimate per row.
        '''
        single_candidates = torch.linspace(0, math.tau, self.pairsize + 1)[:-1]
        pair_candidates = torch.cartesian_prod(*(single_candidates,) * 2)
        grid_candidates = torch.linspace(0, math.tau, self.gridsize + 1)[:-1]

        x_pairs = make_shift(
            state['x_start'],
            torch.cat((torch.tensor(self.lsw.shifts)[None], pair_candidates)),
            axis,
        )
        x_tests = make_shift(
            state['x_start'],
            grid_candidates,
            axis
        )

        # Presumably the posterior std at the test points after hypothetically observing each pair — verify.
        peek_distr = model.peek_posterior(x_pairs, x_tests, diag=True)
        post_distr = model.posterior(x_tests, diag=True)
        coremask = peek_distr.std < state.get('corethresh', self.corethresh)
        state['corerate'] = (coremask.sum() / coremask.numel()).item()
        # Always keep the zero-shift test point in the core so that every minimum below is finite.
        coremask[:, 0] = True

        best_distr = model.posterior(state['x_start'][None], diag=True)
        samplesize = (self.pairsize ** 2 + 1, self.samplesize)
        # Sample in the same order as before (incumbent first) to keep the random stream unchanged.
        best_samples = best_distr.sample(samplesize).squeeze(-1)
        post_samples = post_distr.sample(samplesize)
        # Exclude non-core test points from the minimum by setting them to +inf. Multiplying by inf instead would
        # map negative samples (e.g. negative energies) to -inf and let excluded points dominate the minimum.
        core_samples = torch.where(coremask[..., None, :], post_samples, post_samples.new_tensor(torch.inf))
        maximp = (
            best_samples - core_samples.amin(-1)
        ).clip_(min=0.).nan_to_num_(posinf=0.0).mean(-1)
        return x_pairs, maximp

    def _corethresh_grad(self, model, state, cheat=False):
        '''Corethresh (kappa) strategy based on the gradient norm at the current pivot point.

        Once ``corethresh_width`` gradient values have been logged, kappa is derived from their mean
        ``pnorm``-norm estimate minus ``corethresh_shift``. With ``cheat=True``, the gradient of
        ``state['true_fn']`` is used instead of the model's posterior gradient.
        '''
        pivot = state['x_start'][None].clone().requires_grad_()
        if cheat:
            y_true = state['true_fn'](pivot, n_readout=0)
            grad = torch.mean(
                torch.autograd.grad(y_true.sum(), pivot)[0].abs() ** self.pnorm
            ).item()
        else:
            grad = torch.mean(
                model.posterior_grad(pivot[None], diag=True).mean[:, :, :, 0].abs() ** self.pnorm
            ).item()
        buf = state.setdefault('grad_log', deque(maxlen=self.corethresh_width))
        buf.append(grad)
        if len(buf) >= self.corethresh_width:
            slope = np.mean(buf) ** (1. / self.pnorm)
            state['corethresh'] = max(
                state['min_corethresh'],
                (slope - state['corethresh_shift']) ** self.corethresh_power * state['corethresh_scale']
            )

    def _corethresh_cheatgrad(self, model, state):
        '''Corethresh (kappa) strategy based on the cheated true gradient norm at the current pivot point.'''
        if 'true_fn' not in state:
            raise RuntimeError('Tried cheating without explicitly allowing!')
        self._corethresh_grad(model, state, cheat=True)

    def _corethresh_linreg(self, model, state):
        '''Corethresh (kappa) strategy based on linear regression of the corewidth-th last energy values.

        The slope is the absolute fitted slope of a least-squares line through the last ``corethresh_width``
        logged reference energies.
        '''
        if self.corethresh_width > 0:
            state.setdefault('energy_log', []).append(state[self._coreref].item())
            if len(state['energy_log']) > self.corethresh_width:
                if 'corethresh_pinv' not in state:
                    # Pseudo-inverse of the [t, 1] design matrix, cached because the window length is fixed.
                    state['corethresh_pinv'] = torch.linalg.pinv(torch.stack([
                        torch.arange(float(self.corethresh_width)),
                        torch.ones(self.corethresh_width),
                    ], axis=-1))
                slope = (
                    state['corethresh_pinv']
                    @ torch.tensor(state['energy_log'][-self.corethresh_width:])
                )[0].abs().item()

                state['corethresh'] = max(
                    state['min_corethresh'],
                    slope ** self.corethresh_power * state['corethresh_scale']
                )

    def choose_measurements(self, model, state):
        '''Choose the next points to measure along ``state['k_best']``.

        Before ``smo_steps`` this defers to SMO's fixed shifts. Afterwards it updates kappa and returns the shift
        pair with the largest estimated improvement.

        Raises:
            RuntimeError: if ``corethresh_strategy`` names no method of this class.
        '''
        if state['step'] < self.smo_steps:
            return super().choose_measurements(model, state)

        corethresh_method_name = f'_corethresh_{self.corethresh_strategy}'
        if not hasattr(self, corethresh_method_name):
            raise RuntimeError(f'No such corethresh strategy: \'{self.corethresh_strategy}\'')

        getattr(self, corethresh_method_name)(model, state)

        if self.coremomentum:
            # Seed the history with the first kappa so the moving average is defined from the first step.
            if 'corethresh_log' not in state:
                state['corethresh_log'] = [state['corethresh']] * (self.coremomentum - 1)
            state['corethresh_log'].append(state['corethresh'])
            state['corethresh'] = sum(state['corethresh_log'][-self.coremomentum:]) / self.coremomentum

        x_pairs, maximp = self._emicore(model, state, state['k_best'])
        bestind = maximp.argmax()

        return x_pairs[bestind]

    def require_stabilize(self, model, state):
        '''Stabilize on SMO's schedule, or while the pivot's posterior std exceeds kappa.

        The second condition is limited to ``core_trials`` consecutive steps; the counter resets on the first step
        without stabilization.
        '''
        stabilize = (
            super().require_stabilize(model, state)
            or (
                state.get('core_trial', 0) < self.core_trials
                and (model.posterior(state['x_pivot'][None]).std > state.get('corethresh', self.corethresh)).all()
            )
        )
        state['core_trial'] = (state.get('core_trial', 0) + 1) if stabilize else 0
        return stabilize

    def choose_axis(self, model, state):
        '''Choose the next axis; the result is meant to be stored as ``state['k_best']``.

        Uses SMO's cyclic choice when ``smo_axis`` is set or before ``smo_steps``. Otherwise it scores every axis
        except the current one by its best pair improvement and returns the highest-scoring axis.
        '''
        if self.smo_axis or state['step'] < self.smo_steps:
            return super().choose_axis(model, state)

        axes = [
            np.unravel_index(k_flat, state['x_start'].shape)
            for k_flat in range(state['x_start'].numel())
            if 'k_best' not in state or np.ravel_multi_index(state['k_best'], state['x_start'].shape) != k_flat
        ]

        axes, maximps = zip(*((axis, self._emicore(model, state, axis)[1].amax(0)) for axis in axes))

        return axes[torch.stack(maximps).argmax()]

    def choose_pivot(self, model, state):
        '''Given the shift candidates, find the best point on the line, optionally refining it along other axes.

        With ``pivot_steps > 0``, refinement runs only when kappa is at most ``model.y_var_default * pivot_scale``.
        It stops early, returning the previous point, when a candidate's posterior std exceeds kappa ('smo' mode)
        or when no accepted candidate moves the point ('loop' mode). ``state['k_best']`` is always restored.

        Raises:
            RuntimeError: if refinement runs with an unknown ``pivot_mode``.
        '''
        k_best = state['k_best']
        x_start, y_start = super().choose_pivot(model, state)
        try:
            # NOTE(review): compares kappa (a std-like threshold) with a variance; confirm the intended units.
            if state.get('corethresh', self.corethresh) <= model.y_var_default * self.pivot_scale:
                for n in range(self.pivot_steps):
                    last = (x_start, y_start)
                    if self.pivot_mode == 'smo':
                        state['k_best'] = self.choose_axis(model, state)
                        x_start, y_start = super().choose_pivot(model, state)
                        if (model.posterior(x_start[None]).std > state.get('corethresh', self.corethresh)).all():
                            return last
                    elif self.pivot_mode == 'loop':
                        k_next = state['k_best']
                        for k_dir in set(product(*(range(dim) for dim in x_start.shape))) - {k_next}:
                            state['k_best'] = k_dir
                            x_next, y_next = super().choose_pivot(model, state)
                            if (
                                (model.posterior(x_next[None]).std <= state.get('corethresh', self.corethresh)).all()
                            ):
                                x_start, y_start, k_next = x_next, y_next, k_dir
                        state['k_best'] = k_next
                        if (last[0] == x_start).all():
                            return last
                    else:
                        raise RuntimeError('No such pivot-mode!')

            return x_start, y_start
        finally:
            state['k_best'] = k_best


class ExpectedImprovement(AcquisitionFunction):
    '''Closed-form expected improvement below the incumbent value ``y_best`` (minimization).'''
    def __call__(self, model, x_cand, y_best, step):
        '''Return the expected improvement of every candidate in ``x_cand`` over ``y_best``.

        Computes ``(f_min - mean) * cdf(f_min) + var * pdf(f_min)``. Assuming a Gaussian posterior, this equals the
        textbook ``(f_min - mu) * Phi(z) + sigma * phi(z)`` because ``pdf(f_min) = phi(z) / sigma``. ``step`` is
        unused.
        '''
        incumbent = y_best.item()
        posterior = model.posterior(x_cand, diag=True)

        density = posterior.pdf(incumbent)
        cumulative = posterior.cdf(incumbent)
        gap = incumbent - posterior.mean
        return gap * cumulative + posterior.var * density
